"""
Wild Bootstrap procedure in FGWAS.

Author: Chao Huang (chaohuang.stat@gmail.com)
Last update: 2017-09-28
"""

import numpy as np
from numpy.linalg import inv
from numpy.linalg import eig
from scipy.stats import chi2
from S2_GSIS import gsis
from stat_label_region import label_region

"""
installed all the libraries above
"""


def wild_bstp(all_zx_mat, inv_q_all_zx, y_design, resy_design, efit_eta, inv_sig_eta, hat_mat, g_num):
    """
        Draw one wild-bootstrap replicate and return the local statistics of its top candidate snp

        The fitted mean (y_design - resy_design) is perturbed with standard-normal multipliers
        (one per subject for efit_eta, one per subject and location for the remaining residual),
        the perturbed responses are multiplied by hat_mat, and every snp gets a global statistic.
        The g_num snps with the largest global statistic are kept; among them, the local
        statistics of the snp whose maximum local statistic is largest are returned.

        :param
            all_zx_mat (matrix): projected snp data (t*n)
            inv_q_all_zx (vector): inverse norm of projected snp data (t, or t*1)
            y_design (matrix): imaging response data (response matrix, n*l*m)
            resy_design (matrix): estimated difference between y_design and X*B (n*l*m)
            efit_eta (matrix): the estimated of eta (n*l*m)
            inv_sig_eta (matrix): the inverse of estimated covariance matrix of eta (l*m*m)
            hat_mat (matrix): hat matrix (n*n)
            g_num (scalar): number of candidate snps (at most t candidates are used)
        :return
            l_stat_top1 (matrix): local statistics of the selected candidate snp (l*1)

        Random multipliers are drawn from the global numpy.random state, so results are
        reproducible only after np.random.seed(...).
    """

    # Set up
    n, l, m = y_design.shape

    # Accept a flat vector as well as the (t*1) column: a column would broadcast the
    # element-wise products below into a (t*t) matrix.
    inv_q = np.ravel(inv_q_all_zx)

    y_mean = y_design - resy_design
    residual = resy_design - efit_eta

    # Wild-bootstrap residuals. The draws stay in a per-subject loop (one scalar, then l
    # values) so the sequence consumed from the global random state is unchanged.
    res_y_bstp = np.empty((n, l, m))
    for nii in range(n):
        rand_sub = np.random.normal(0, 1, 1)
        rand_vex = np.random.normal(0, 1, l)
        # rand_vex is shared by the m measures of a location, hence the new trailing axis.
        res_y_bstp[nii] = rand_sub * efit_eta[nii] + rand_vex[:, np.newaxis] * residual[nii]

    # hat_mat applied to every (location, measure) column at once -> (n*l*m)
    smy_design = np.tensordot(hat_mat, y_mean + res_y_bstp, axes=(1, 0))
    del y_mean, residual, res_y_bstp

    # const = average over locations of S_l * inv_sig_eta[l] * S_l', with S_l = smy_design[:, l, :].
    # The slice is kept 2-D (no squeeze) so that m == 1 works as well.
    const = np.zeros((n, n))
    for lii in range(l):
        smy_l = smy_design[:, lii, :]
        const += np.dot(np.dot(smy_l, inv_sig_eta[lii]), smy_l.T) / l

    # Symmetric square root of const. eigh (after symmetrising away round-off) returns real
    # eigenvalues and orthonormal eigenvectors, which v * sqrt(w) * v' relies on.
    const = (const + const.T) / 2
    w, v = np.linalg.eigh(const)
    w = np.clip(w, 0, None)  # negative eigenvalues are treated as zero
    sq_qr_smy_mat = np.dot(v * np.sqrt(w), v.T)

    # Global statistic per snp, then rank snps from largest to smallest.
    g_stat = np.sum(np.dot(all_zx_mat, sq_qr_smy_mat) ** 2, axis=1) * inv_q
    indx = np.argsort(-g_stat)
    del const, sq_qr_smy_mat, g_stat

    # Local statistics z * S_l * inv_sig_eta[l] * S_l' * z' for the top candidates, evaluated
    # through the projections z * S_l so that no (n * n*l) matrix has to be stored.
    top = indx[:g_num]
    proj = np.tensordot(all_zx_mat[top], smy_design, axes=(1, 0))  # (g*l*m)
    weighted = np.einsum('glm,lmk->glk', proj, inv_sig_eta)
    l_stat_top = np.sum(weighted * proj, axis=2) * inv_q[top][:, np.newaxis]

    # Keep the candidate whose largest local statistic is the overall maximum.
    idx_max_stat_bstp = np.argmax(np.max(l_stat_top, axis=1))
    l_stat_top1 = np.reshape(l_stat_top[idx_max_stat_bstp, :], (l, 1))

    return l_stat_top1
